""" DiffGro Implementation """
import os
import time
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import cv2
import gym
import numpy as np
import clip

from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim

from diffgro.environments import task_to_object
from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.common.models.utils import cos_sim
from diffgro.diffgro.planner import DiffGroPlanner
from diffgro.diffgro.functions import guide_fn_dict, _loss_txt, _manual_loss_fn
from diffgro.utils import llm
from diffgro.utils import *


class DiffGro:
    """Step-wise controller around a DiffGro planner, with optional loss guidance.

    Each call to :meth:`predict` writes the observation into a rolling
    observation/action stack of length ``horizon``, (re)infers a prior context
    with the planner, lets the planner complete the masked part of the stack and
    returns the action for the current slot.  When ``guide`` is set, sampling is
    steered by a loss function that is looked up from ``guide_fn`` by
    ``guide_pt`` ('test'), built by ``guide_fn`` from each context's type and
    target ('manual') or generated by an LLM ('llm'); 'blank' runs unguided.

    :meth:`reset` must be called before the first :meth:`predict` of an episode
    (it creates the counters and stacks), and :meth:`_setup_guide` must be called
    after ``context_info`` has been assigned when using the 'test', 'manual' or
    'llm' guides (it builds ``loss_fn``).
    """

    guide_methods = ['blank', 'test', 'manual', 'llm']
    # Number of steps between two image queries in multimodal ('llm') mode.
    MULTIMODAL_INTERVAL = 40
    # Attempts made to obtain an executable LLM-generated loss function per context.
    MAX_CODE_ATTEMPTS = 9

    def __init__(
        self,
        env: gym.Env,
        planner: DiffGroPlanner,
        history: Optional[int] = None,    # history to stack
        guide: Optional[str] = None,      # guide function
        guide_pt: Optional[str] = None,   # prompt or path for llm guidance
        multimodal: bool = False,         # text-vision
        validate: bool = True,            # llm validation
        delta: float = 1.0,               # scale for guidance
        save_path: str = './results',
        verbose: bool = False,
        debug: bool = False,
    ):
        """
        Args:
            env: environment; must expose ``observation_space``, ``action_space``,
                ``env_name`` and ``domain_name`` (plus ``full_task_list``,
                ``task_list``, ``success_count`` for 'metaworld_complex').
            planner: trained planner whose ``policy`` is used for inference.
            history: number of stacked steps kept once the stack is full;
                defaults to ``int(horizon / 2)``.
            guide: one of ``guide_methods``, or None for no guidance.
            guide_pt: key into the 'test' guide functions; replaced by the list of
                per-context guide prompts in :meth:`_setup_guide` ('manual'/'llm').
            multimodal: ask a vision LLM whether the arm contacts the object;
                'llm' guidance is skipped until it answers Yes.
            validate: let the LLM validate guided plans ('llm').
            delta: guidance scale passed to the planner.
            save_path: directory where generated loss-function code is logged.
            verbose: forwarded to the planner's ``_predict_act``.
            debug: stored; not read by this class.
        """
        self.env = env
        self.planner = planner.policy
        self.history = history
        # guidance
        if guide is not None:
            assert guide in DiffGro.guide_methods, f"Guide method {guide} should be in {DiffGro.guide_methods}"
        self.guide = guide
        self.guide_fn = guide_fn_dict[guide] if guide is not None else None
        self.guide_pt = guide_pt  # context
        self.context_info = None
        self.multimodal = multimodal
        self.validate = validate
        self.delta = delta
        # Same default _setup_guide() starts from; needed when it is never called
        # (e.g. no guidance), because the unguided planner call reads it.
        self.n_guide_steps = 1
        self.guided = False
        # misc
        self.save_path = save_path
        self.verbose = verbose
        self.debug = debug
        self._setup()

    def _setup(self) -> None:
        """Read dimensions from the env/planner and compute task (and skill) embeddings."""
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.act_dim = get_act_dim(self.env.action_space)
        self.horizon = self.planner.horizon

        # history settting
        if self.history is None:
            self.history = int(self.horizon / 2)
        print_b(f"[diffgro] History stack is set as {self.history}")

        # task embedding
        self.task = get_skill_embed(None, self.env.env_name).reshape(1, -1)
        if self.env.domain_name == 'metaworld_complex':
            # one skill embedding per sub-task, indexed by env.success_count in predict_pri
            self.skill = [get_skill_embed(None, task).reshape(1, -1) for task in self.env.full_task_list]

    def _setup_guide(self) -> None:
        """Build the list of guidance loss functions (``self.loss_fn``).

        For 'manual' and 'llm', every entry of ``self.context_info`` is read as
        ``(loss prompt, guide prompt, context type, context target, delta)``
        (indices 0-4). ``delta`` is a single scalar, so the last context's value wins.
        """
        self.n_guide_steps = 1
        if self.guide == 'test':
            self.loss_fn = [self.guide_fn[self.guide_pt] for _ in range(self.env.task_num)]
        elif self.guide == 'blank':  # no guidance only for evaluating contexts
            self.n_guide_steps = 0
        elif self.guide == 'manual':
            self.loss_fn, self.guide_pt, self.loss_pt = [], [], []
            for context in self.context_info:
                self.loss_pt.append(context[0])
                self.guide_pt.append(context[1])
                loss_fn, _ = self.guide_fn(context_type=context[2], context_target=context[3])
                self.loss_fn.append(loss_fn)
                self.delta = context[4]
        elif self.guide == 'llm':
            self.loss_fn, self.guide_pt, self.loss_pt, self.prev_answer = [], [], [], []
            for context in self.context_info:
                self.loss_pt.append(context[0])
                self.guide_pt.append(context[1])
                loss_fn, prev_answer = self._generate_loss_fn(context[0])
                self.loss_fn.append(loss_fn)
                self.prev_answer.append(prev_answer)  # conversation reused on regeneration
                self.delta = context[4]

        print_b(f"[diffgro] guidance function is '{self.guide}' and scale is '{self.delta}'")
        print_b(f"[diffgro] the guide prompt is '{self.guide_pt}'")

    def reset(self) -> None:
        """Clear step counters, rolling stacks and per-episode multimodal state.

        Also called after every step by :meth:`predict` when ``history == 0``.
        """
        self.t, self.h = 0, 0
        self.obs_stack = np.zeros((1, self.horizon, self.obs_dim))
        self.act_stack = np.zeros((1, self.horizon, self.act_dim))
        self.obj_contact = False
        self._last_img_t = None  # step of the last image queried in multimodal mode

    def predict(self, obs: np.ndarray, deterministic: bool = True) -> Tuple[np.ndarray, None, Dict[str, Any]]:
        """Return the action for a single (unbatched) observation.

        ``deterministic`` is accepted for API compatibility; all planner calls
        are deterministic.

        Returns:
            ``(act, None, info)`` where ``info["h"]`` is the step counter after
            this step and ``info["guided"]`` tells whether guidance was applied.
        """
        np.set_printoptions(precision=4, suppress=True)

        # add batch dimension
        obs = obs.reshape((-1,) + obs.shape)

        # 1. inference prior
        self.ctx = self.predict_pri(obs)

        # 2-1. conditioning
        self.obs_stack[0][self.h] = obs[0]
        cond = np.concatenate((self.obs_stack, self.act_stack), axis=-1)
        # 2-2. masking: observations up to and including slot h are known,
        # actions only up to slot h-1 (the action at slot h is being planned)
        mask_obs = np.concatenate((np.ones((1, self.h + 1, self.obs_dim)), np.zeros((1, self.horizon - self.h - 1, self.obs_dim))), axis=1)
        if self.h == 0:
            mask_act = np.zeros((1, self.horizon, self.act_dim))
        else:
            mask_act = np.concatenate((np.ones((1, self.h, self.act_dim)), np.zeros((1, self.horizon - self.h, self.act_dim))), axis=1)
        mask = np.concatenate((mask_obs, mask_act), axis=-1)

        # 3. image inference (multimodal LLM guidance only)
        if self.multimodal and self.guide == 'llm':
            query_now = self.t % self.MULTIMODAL_INTERVAL == 0
            if self.env.domain_name == 'metaworld':
                if query_now and not self.obj_contact:
                    self.obj_contact = self._multimodal_loss_fn()
            elif self.env.domain_name == 'metaworld_complex':
                if self.env.task_list[self.env.env.env.mode] == 'drawer':
                    if query_now and not self.obj_contact:
                        self.obj_contact = self._multimodal_loss_fn()
                else:
                    self.obj_contact = False
                # NOTE(review): this disables validation permanently (reset() does
                # not restore it) -- confirm this is intended across episodes.
                if self.t > 400:
                    self.validate = False

        # 3. inference planner
        act = self.predict_act(cond, mask, self.ctx)
        self.act_stack[0][self.h] = act

        self.t += 1
        self.h += 1

        # 4. stacking
        if self.history != 0:
            if self.h == (self.horizon - 1):
                # Keep rows [horizon-history-1, horizon-1) (the last slot is never
                # written) and resume at slot history-1, where predict_pri re-infers ctx.
                # NOTE(review): the next step overwrites slot history-1, which holds the
                # previous step's data -- confirm this is intended.
                self.obs_stack = np.concatenate((self.obs_stack[:, -self.history - 1:-1, :], np.zeros((1, self.horizon - self.history, self.obs_dim))), axis=1)
                self.act_stack = np.concatenate((self.act_stack[:, -self.history - 1:-1, :], np.zeros((1, self.horizon - self.history, self.act_dim))), axis=1)
                self.h = self.history - 1
        else:  # if no history, we reset for every state
            self.reset()

        act = np.array(act.copy())
        return act, None, {"h": self.t, "guided": self.guided}

    def predict_pri(self, obs: np.ndarray) -> np.ndarray:
        """Return the prior context for the current step.

        The planner's prior is queried at ``h == 0`` (from ``obs``) and, when
        ``history != 0``, at ``h == history - 1`` (from the first stacked
        observation); otherwise the previous ``self.ctx`` is reused.
        """
        task, skill = self.task, None
        if self.env.domain_name == 'metaworld_complex':
            task, skill = self.task, self.skill[self.env.success_count]

        if self.h == 0:
            ctx = self.planner._predict_pri(obs, task, skill, deterministic=True)
        elif self.history != 0 and self.h == (self.history - 1):
            ctx = self.planner._predict_pri(self.obs_stack[0][0].reshape(-1, self.obs_dim), task, skill, deterministic=True)
        else:
            ctx = self.ctx
        return ctx

    def predict_act_without_guide(self, cond, mask, ctx):
        """Plan without a loss function; return ``(plan, action at slot h)``."""
        self.guided = False
        plan, _ = self.planner._predict_act(
            cond, mask, ctx, None, delta=None, guide_fn=None, n_guide_steps=self.n_guide_steps, deterministic=True, verbose=self.verbose)
        act = plan[:, self.h, -self.act_dim:][0]
        return plan, act

    def predict_act_with_guide(self, cond, mask, ctx, loss_fn: Optional[Callable] = None):
        """Plan guided by ``loss_fn``; return ``(plan, action at slot h)``.

        When ``loss_fn`` is None, the loss function of the current context is
        used (``loss_fn[0]``, or ``loss_fn[env.success_count]`` for
        'metaworld_complex'). If that sub-task has no loss function, planning is
        unguided and ``self.guided`` is set to False.
        """
        self.guided = True
        if loss_fn is None:
            loss_fn = self.loss_fn[0]
            if self.env.domain_name == 'metaworld_complex':
                try:
                    loss_fn = self.loss_fn[self.env.success_count]
                except IndexError:  # more sub-tasks than guidance contexts
                    loss_fn = None
                    self.guided = False
        plan, _ = self.planner._predict_act(
            cond, mask, ctx, None, delta=self.delta, guide_fn=loss_fn, n_guide_steps=self.n_guide_steps, deterministic=True, verbose=self.verbose)
        act = plan[:, self.h, -self.act_dim:][0]
        return plan, act

    def predict_act(self, cond, mask, ctx):
        """Return the action for slot ``h`` according to the configured guide method.

        Raises:
            NotImplementedError: for a guide method without a planning strategy.
        """
        if (self.guide_fn is None) or (self.t < self.history):
            _, act = self.predict_act_without_guide(cond, mask, ctx)
        elif self.guide == 'blank':
            _, act = self.predict_act_without_guide(cond, mask, ctx)
        elif self.guide in ('test', 'manual'):
            _, act = self.predict_act_with_guide(cond, mask, ctx)
        elif self.guide == 'llm':
            if self.t == self.history:  # regenerate code
                act = self._verify_llm_loss_fns(cond, mask, ctx)
            elif self.multimodal and (not self.obj_contact):
                _, act = self.predict_act_without_guide(cond, mask, ctx)
            elif self.t % self.horizon == 0 and self.validate:
                act = self._predict_act_validated(cond, mask, ctx)
            else:
                _, act = self.predict_act_with_guide(cond, mask, ctx)
        else:
            raise NotImplementedError

        return act

    def _verify_llm_loss_fns(self, cond, mask, ctx):
        """Ensure every LLM-generated loss function runs; return the first context's action.

        A loss function that raises is regenerated by the LLM. After
        ``MAX_CODE_ATTEMPTS`` failed attempts guidance is disabled
        (``self.guide_fn = None``) and the unguided action is used for that context.
        """
        print_y(f"<< code verification >>")
        _, act_origin = self.predict_act_without_guide(cond, mask, ctx)

        # One action per context; contexts that cannot be guided keep the unguided action.
        acts = [act_origin] * len(self.loss_fn)
        for i in range(len(self.loss_fn)):
            for _ in range(self.MAX_CODE_ATTEMPTS):
                try:
                    _, act_guided = self.predict_act_with_guide(cond, mask, ctx, loss_fn=self.loss_fn[i])
                except Exception as err:  # generated code can fail in arbitrary ways
                    print_r(f"Generated loss function failed: {err}")
                    self.loss_fn[i] = self._regenerate_loss_fn(self.loss_pt[i], self.prev_answer[i])
                    continue
                acts[i] = act_guided
                break
            else:
                print_r("We cannot guide this context !!!")
                self.guide_fn = None
        # Only the first context's action is executed at this step.
        return acts[0] if acts else act_origin

    def _predict_act_validated(self, cond, mask, ctx):
        """Guide with a growing number of guidance steps until the LLM accepts the plan.

        Starts at one guidance step and adds one after every rejected plan, for at
        most ``max_guide_step`` attempts; ``self.n_guide_steps`` keeps its final
        value for the following steps. Returns the last guided action.
        """
        plan_origin, _ = self.predict_act_without_guide(cond, mask, ctx)
        self.n_guide_steps = 1
        max_guide_step = 4

        for _ in range(max_guide_step):
            plan_guided, act_guided = self.predict_act_with_guide(cond, mask, ctx)
            if not self.guided:
                # No loss function (and no guide prompt) for the current sub-task.
                break
            idx = 0 if self.env.domain_name == 'metaworld' else self.env.success_count
            flag = self._validate_loss_fn(
                self.guide_pt[idx], plan_origin[:, :, -self.act_dim:-1], plan_guided[:, :, -self.act_dim:-1])
            if flag:
                break
            self.n_guide_steps += 1
        return act_guided

    def _save_code(self, header: str, answer: str, code: str) -> None:
        """Append the raw LLM answer to code.txt and the extracted code to evaluation.txt."""
        os.makedirs(self.save_path, exist_ok=True)
        with open(os.path.join(self.save_path, 'code.txt'), 'a', encoding='utf-8') as f:
            f.write(header + '\n')
            f.write(answer + '\n')
        with open(os.path.join(self.save_path, 'evaluation.txt'), 'a', encoding='utf-8') as f:
            f.write('```\n')
            f.write(code + '\n')
            f.write('```\n')

    def _generate_loss_fn(self, guide_pt: str) -> Tuple[Callable, str]:
        """Ask the LLM (chain of thought) for a loss function for ``guide_pt``.

        Returns:
            The compiled loss function and the concatenated conversation, which is
            reused as context by :meth:`_regenerate_loss_fn`.
        """
        prompt1 = f"The shape of the action sequence of robot arm is (B, H, 4) where B is the batch size,  H is the number of sequence, and 4 represents the (x, y, z, gripper on/off) in corresponding agent coordinate of a robot arm. Explain the configuration of robot arm's action."
        prompt2 = f"The robot arm should  satisfy the given user requirement. The user requirement is to {guide_pt}. Note that speed of the robot arm is determined by the L2 norm of the actions. In orther to satisfy the user requirement, What should be considered?"
        prompt3 = f"Now, generate a loss function that guides the trajectory to satisfy the user requirement. Generated python loss function should follow the following format: 'def _loss_fn(x, obs_dim): act = x[:,:,obs_dim:-1] return loss'. act is a numpy array representing the action sequences. Use jnp instead of np and the generated loss function should be complied in just in time. Furthermore, _loss_fn should not call other functions. The loss function should return the loss for given user's preference. {guide_pt}."
        prompts = [prompt1, prompt2, prompt3]

        # chain of thought prompting
        print_r(f"def _generate")
        start = time.time()
        answer, prompt_cat = llm.chain_of_thought(prompts)
        end = time.time()
        print(f"Time for gpt: {end - start}")
        print_b(f"{prompt_cat}")

        # loss function generation
        loss_fn, code = _loss_txt(answer)
        self._save_code("def _generate ...", answer, code)
        return loss_fn, prompt_cat

    def _regenerate_loss_fn(self, guide_pt: str, prev_answer: str) -> Callable:
        """Ask the LLM to fix a failing loss function, given the previous conversation.

        Returns:
            The newly compiled loss function.
        """
        prompt = f"The previously generated code made an error. Regenerate the Python code. The loss function should follow the following format: 'def _loss_fn(x, obs_dim): act = x[:,:,obs_dim:-1] return jnp.mean(loss)'. Note that the user requirement was {guide_pt}."
        prompt_cat = prev_answer + f"Q: {prompt}\n"

        print_r(f"def _regenerate at {self.t}")
        start = time.time()
        answer = llm.chatgpt(prompt_cat, temperature=0.5)
        end = time.time()
        print(f"Time for gpt: {end - start}")

        loss_fn, code = _loss_txt(answer)
        self._save_code("def _regenerate ...", answer, code)
        return loss_fn

    def _validate_loss_fn(self, guide_pt: str, act_origin: np.ndarray, act_guided: np.ndarray) -> bool:
        """Ask the LLM whether the guided actions satisfy ``guide_pt``; True if it answers 'Yes'."""
        prompt1 = f"The original action was {act_origin} and modified action is now {act_guided}. The shape of the action sequence of robot arm is (B, H, 3) where B is the batch size,  H is the number of sequence, and 3 represents the (x, y, z) in corresponding agent coordinate of a robot arm. The user's requirement is '{guide_pt}'. Note that speed of the robot arm is determined by the L2 norm of the actions. Did the modified action satisfies the user requirement? Answer step by step."
        prompt2 = f"According to the answer above, did the modified action meet's the user's requirement, Answer with 'Yes' or 'No'."
        prompts = [prompt1, prompt2]

        # chain of thought prompting
        print_r(f"def _validate at {self.t} with {self.n_guide_steps}")
        start = time.time()
        answer, _ = llm.chain_of_thought(prompts)
        end = time.time()
        print(f"Time for gpt: {end - start}")

        # if meets user-defined context, done
        return 'Yes' in answer

    def image2binary(self, camera_pos) -> str:
        """Render the env from ``camera_pos``, save it as ./results/img/{t}.png and return that path."""
        frame = self.env.render(camera_name=camera_pos)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        make_dir('./results/img')
        path = f'./results/img/{self.t}.png'
        cv2.imwrite(path, frame)
        return path

    def _multimodal_loss_fn(self) -> bool:
        """Ask a vision LLM whether the robot arm is in contact with the task object.

        The first query of an episode sends only the current frame; later queries
        also send the previously queried frame together with the previous answer.
        Returns True when the final Yes/No answer contains 'Yes'.
        """
        object_name, camera_pos, task_prompt, query_prompt = task_to_object(self.env.domain_name, self.env.env_name)
        image = self.image2binary(camera_pos)
        prev_img_t = getattr(self, '_last_img_t', None)

        print_r(f"def _determine (1) ... at {self.t}")
        if prev_img_t is None:
            prompt = f"{task_prompt}. Given the image, is the robot arm is contacting on the {object_name}? or not contacting on the {object_name}?"
            answer = llm.chatgpt_vision(image, prompt)
        else:
            prev_image = f'./results/img/{prev_img_t}.png'
            prompt = f"{task_prompt}. For the second image, {query_prompt} The answer for the first image was {self.multimodal_answer}. Please make an answer step by step, referring to the first image."
            answer = llm.chatgpt_vision_multiple(prev_image, image, prompt)

        print_b(f"Question: {prompt}")
        print_b(f"Answer: {answer}")

        prompt = f"Q: {prompt}\n, A: {answer}\n" + "According to the answer above, answer with Yes or No."
        print_r("def _determine (2) ... ")
        answer = llm.chatgpt(prompt)
        print_b(f"Answer: {answer}")

        contact = 'Yes' in answer
        self.multimodal_answer = "Yes" if contact else "No"
        # Recorded only after a complete query, so the next query can reference it.
        self._last_img_t = self.t
        return contact
